import ml_collections

def get_config():
    """Build and return the experiment configuration as a ``ConfigDict``.

    Every field is a literal constant. Fields are assigned in a fixed order,
    and that order is preserved so iterating or printing the returned config
    lists keys in the same sequence. A nested ``transformer`` sub-config
    holds the transformer-specific settings.
    """
    cfg = ml_collections.ConfigDict()

    # Embedding / EM-iteration / prototype settings.
    cfg.emb_dim = 1024
    cfg.em_iter = 3
    cfg.tau = 0.01
    cfg.ot_eps = 0.1
    cfg.heads = 1
    cfg.out_type = "allcat"
    cfg.load_proto = True
    cfg.proto_path = "."
    cfg.fix_proto = False
    cfg.fix_em_proto = False

    # Settings grouped under the "IndivMLPEmb_IndivPost" label.
    # NOTE(review): earlier revisions carried commented-out alternatives
    # `in_dim = 1024` and `n_classes = 4`; confirm 2049 is the intended in_dim.
    cfg.in_dim = 2049
    cfg.shared_embed_dim = 256
    cfg.indiv_embed_dim = 128
    cfg.postcat_embed_dim = 1024
    cfg.shared_mlp = False
    cfg.indiv_mlps = True
    cfg.postcat_mlp = True
    cfg.n_fc_layers = 1
    cfg.shared_dropout = 0.1
    cfg.indiv_dropout = 0.1
    cfg.postcat_dropout = 0.1

    # Attention / projection settings.
    cfg.feature_dim = 1024
    cfg.attn_dim = 768
    cfg.num_heads = 12
    cfg.attn_dropout_rate = 0.0
    cfg.proj_dropout_rate = 0.2

    # Transformer sub-config: populated locally, then attached as a single
    # nested ConfigDict field.
    transformer_cfg = ml_collections.ConfigDict()
    transformer_cfg.mlp_dim = 2 * 1024
    transformer_cfg.dropout_rate = 0.0
    cfg.transformer = transformer_cfg

    return cfg